Exploring the MNIST Digits Dataset

Introduction

The MNIST digits dataset is a famous dataset of handwritten digit images. You can read more about it on Wikipedia or Yann LeCun's page. It's a useful dataset because it provides an example of a pretty simple, straightforward image processing task, for which we know exactly what state-of-the-art accuracy is.

I plan to use this dataset for a couple of upcoming machine learning blog posts, and since the first step of pretty much any ML task is 'explore your data,' I figured I would post this first, so I have something to refer back to instead of repeating it in each subsequent post.

Conveniently, scikit-learn has a built-in utility for loading this (and other) standard datasets.

Loading the Digits Dataset


In [22]:
import random

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [23]:
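# fetch_mldata was removed from scikit-learn (0.22+); fetch_openml is its replacement
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1, data_home='datasets/', as_frame=False)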

In [24]:
# Convert sklearn 'datasets bunch' object to Pandas DataFrames
y = pd.Series(mnist.target).astype('int').astype('category')
X = pd.DataFrame(mnist.data)

Data Shape, Summary Stats

We can see below that our data (X) and target (y) have 70,000 rows, meaning we have information on 70,000 digit images. Our X, or independent variables dataset, has 784 columns, which correspond to the 784 pixel values in a 28-pixel x 28-pixel image (28x28 = 784). Our y, or target, is a single column representing the true digit labels (0-9) for each image.


In [18]:
X.shape, y.shape


Out[18]:
((70000, 784), (70000,))

In [46]:
# Change column-names in X to reflect that they are pixel values
num_pixels = X.shape[1]
X.columns = ['pixel_'+str(x) for x in range(num_pixels)]

# print first row of X
X.head(1)


Out[46]:
pixel_0 pixel_1 pixel_2 pixel_3 pixel_4 pixel_5 pixel_6 pixel_7 pixel_8 pixel_9 ... pixel_774 pixel_775 pixel_776 pixel_777 pixel_778 pixel_779 pixel_780 pixel_781 pixel_782 pixel_783
0 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0

1 rows × 784 columns
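
As a quick sanity check on the column naming above, here is a minimal sketch of how a flat pixel index maps back to a position in the image. It assumes the 784 values are stored row by row (row-major order), which is what reshaping to 28x28 implies; `pixel_position` and `IMAGE_SIDE` are hypothetical names for illustration, not part of scikit-learn or pandas.

```python
import numpy as np

IMAGE_SIDE = 28  # MNIST images are 28 x 28 pixels


def pixel_position(flat_index, side=IMAGE_SIDE):
    """Return (row, col) for a flat pixel index, assuming row-major order."""
    return divmod(flat_index, side)


# A stand-in "image" whose pixel values equal their own flat index
flat = np.arange(IMAGE_SIDE * IMAGE_SIDE)
grid = flat.reshape(IMAGE_SIDE, IMAGE_SIDE)

row, col = pixel_position(406)
print(row, col)        # position of the column named pixel_406
print(grid[row, col])  # reshaping agrees with the divmod mapping
```

Under that assumption, pixel_0 is the top-left corner and pixel_783 is the bottom-right corner, which would explain why the first and last few columns of the row above are all 0.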

Below we see the min, max, mean, median and most-common pixel-intensity values across all of our rows/images. As suggested by the first row above, our most common value is 0. In fact even the median is 0, which means at least half of our pixels are background/blank space. Makes sense.


In [37]:
X_values = pd.Series(X.values.ravel())
print(" min: {}, \n max: {}, \n mean: {}, \n median: {}, \n most common value: {}".format(X_values.min(), 
                                                                                          X_values.max(), 
                                                                                          X_values.mean(),
                                                                                          X_values.median(), 
                                                                                          X_values.value_counts().idxmax()))


 min: 0, 
 max: 255, 
 mean: 33.385964741253645, 
 median: 0.0, 
 most common value: 0
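
To put a number on "how much of the data is background," one option is to compute the share of pixels equal to 0 directly. The sketch below uses a tiny synthetic array rather than the real data, and `background_fraction` is a hypothetical helper name.

```python
import numpy as np


def background_fraction(pixels, background_value=0):
    """Share of pixel values equal to the background value."""
    pixels = np.asarray(pixels)
    return float(np.mean(pixels == background_value))


# Tiny synthetic stand-in: 2 "images" of 5 pixels each
toy = np.array([[0, 0, 0, 128, 255],
                [0, 0, 64, 0, 0]])

frac = background_fraction(toy)
print(frac)            # 7 of the 10 values are 0
print(np.median(toy))  # the median is 0 because more than half the values are 0
```

On the real data the equivalent call would be `background_fraction(X.values)`.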

We might wonder if there are only a few distinct pixel values present in the data (e.g. black, white, and a few shades of grey), but in fact we have all 256 values between our min-max of 0-255:


In [47]:
len(np.unique(X.values))


Out[47]:
256
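
If we also wanted to know how often each distinct value occurs, `np.unique` can return the counts alongside the values. A minimal sketch on toy data (`value_histogram` is a hypothetical helper name):

```python
import numpy as np


def value_histogram(pixels):
    """Map each distinct pixel value to the number of times it occurs."""
    values, counts = np.unique(np.asarray(pixels), return_counts=True)
    return dict(zip(values.tolist(), counts.tolist()))


toy = np.array([[0, 0, 255],
                [0, 17, 255]])

hist = value_histogram(toy)
print(len(hist))  # number of distinct values
print(hist)       # count for each distinct value
```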

Viewing the Digit Images

We can also take a look at the digit images themselves, with matplotlib's handy function pyplot.imshow().

imshow accepts a dataset to plot, which it will interpret as pixel values. It also accepts a color-mapping to determine the color each pixel-value should be displayed as.

In the code below, we'll plot our first row/image, using the "reverse grayscale" color-map, to plot 0 (background in this dataset) as white.
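
To make the color-map idea concrete, here is a small sketch that asks matplotlib which colors the regular and reversed grayscale maps assign to our lowest and highest pixel values. It assumes intensities are scaled over the full 0-255 range. By default imshow scales to the min and max of the array it is given, which comes to the same thing for an image containing both 0 and 255.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs without a display
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize

# Scale raw intensities 0-255 into the 0-1 range a color-map expects
norm = Normalize(vmin=0, vmax=255)

gray = plt.get_cmap("gray")
gray_r = plt.get_cmap("gray_r")

background_rgba = gray_r(norm(0))  # color used for pixel value 0
ink_rgba = gray_r(norm(255))       # color used for pixel value 255

print(background_rgba)  # RGBA of the background under 'gray_r'
print(ink_rgba)         # RGBA of the darkest "ink" under 'gray_r'
print(gray(norm(0)))    # the non-reversed map flips the two
```

With 'gray_r', value 0 comes out as white and 255 as black, so the digits look like dark ink on a light page.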


In [39]:
# First row is first image
first_image = X.loc[0,:]
first_label = y[0]

# 784 columns correspond to 28x28 image
plottable_image = np.reshape(first_image.values, (28, 28))

# Plot the image
plt.imshow(plottable_image, cmap='gray_r')
plt.title('Digit Label: {}'.format(first_label))
plt.show()


And here are a few more...


In [44]:
images_to_plot = 9
random_indices = random.sample(range(70000), images_to_plot)

sample_images = X.loc[random_indices, :]
sample_labels = y.loc[random_indices]

In [45]:
# Note: on matplotlib >= 3.6 this style is named 'seaborn-v0_8-muted'
plt.style.use('seaborn-muted')

fig, axes = plt.subplots(3,3,
                         figsize=(5,5),
                         sharex=True, sharey=True,
                         subplot_kw=dict(adjustable='box', aspect='equal')) # 'box-forced' was removed from matplotlib; see https://stackoverflow.com/q/44703433/1870832

for i in range(images_to_plot):
    
    # axes (subplot) objects are stored in 2d array, accessed with axes[row,col]
    subplot_row = i//3 
    subplot_col = i%3  
    ax = axes[subplot_row, subplot_col]

    # plot image on subplot
    plottable_image = np.reshape(sample_images.iloc[i,:].values, (28,28))
    ax.imshow(plottable_image, cmap='gray_r')
    
    ax.set_title('Digit Label: {}'.format(sample_labels.iloc[i]))
    ax.set_xbound([0,28])

plt.tight_layout()
plt.show()



Final Wrap-up

One last thing I'd want to check before moving forward with any classification task is how balanced our dataset is. Do we have a pretty even distribution of each digit? Or do we have mostly 7s, for example?


In [6]:
y.value_counts(normalize=True)


Out[6]:
1    0.112529
7    0.104186
3    0.102014
2    0.099857
9    0.099400
0    0.098614
6    0.098229
8    0.097500
4    0.097486
5    0.090186
dtype: float64

Great. Eight of the ten digits each account for between 9.7% and 10.4% of our rows. The most common digit is 1, at 11.3% of rows, and the least common is 5, at 9.0%. Overall, a very well-balanced dataset.
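
Two numbers summarize this balance nicely: the accuracy of a "classifier" that always guesses the most common digit, and the ratio between the largest and smallest classes. The sketch below hard-codes the proportions from the output above; the variable names are just for illustration.

```python
# Class proportions copied from the value_counts output above
proportions = {
    1: 0.112529, 7: 0.104186, 3: 0.102014, 2: 0.099857, 9: 0.099400,
    0: 0.098614, 6: 0.098229, 8: 0.097500, 4: 0.097486, 5: 0.090186,
}

most_common = max(proportions, key=proportions.get)
least_common = min(proportions, key=proportions.get)

# Accuracy of a "classifier" that always predicts the most common digit
majority_baseline = proportions[most_common]

# Ratio of the largest class to the smallest class
imbalance_ratio = proportions[most_common] / proportions[least_common]

print(most_common, least_common)
print(round(majority_baseline, 3), round(imbalance_ratio, 2))
```

Any model worth keeping should comfortably beat that majority baseline, and a largest-to-smallest ratio this close to 1 means plain accuracy is a reasonable metric here.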